#!/usr/bin/env python

import getpass
import json
import logging
import pathlib
import sys
from datetime import datetime, timedelta, time
from logging import handlers
from typing import Tuple, Union
import pytz
import click
import pandas as pd
import requests
from absl import flags, app
from easydict import EasyDict
from futu import TrdEnv, TradeDealHandlerBase, RTDataHandlerBase, TrdSide, \
    OrderType, TrdMarket, KLType, OpenQuoteContext, OpenSecTradeContext, TradeOrderHandlerBase, RET_OK, \
    SubType, SecurityFirm, SysConfig
from ruamel import yaml

from quant import strategy, utils

# Command-line flags, both mandatory:
#   --code    trading code of the target (also names the `<code>.yml` config)
#   --config  section of that YAML file to use
flags.DEFINE_string('code', None, 'Trading code of target. e.g. HK.00700')
flags.DEFINE_string('config', None, 'Config name')
flags.mark_flags_as_required(['code', 'config'])

FLAGS = flags.FLAGS
# Timezone used for session checks and "now" timestamps throughout the script.
TZ = pytz.timezone('Asia/Hong_Kong')


class StrategyContext(object):
    """Runtime context for a grid strategy traded through Futu OpenD.

    Owns the quote/trade connections, the strategy instance and the
    bookkeeping (last deal, previous signal, notification throttle counter)
    that the push handlers (OnRTClass, OnOrderClass, OnFillClass) read and
    update.
    """

    # Timeout (seconds) for every outbound notification HTTP request, so a
    # stalled notification endpoint cannot block a market-data callback.
    _HTTP_TIMEOUT = 10

    def __init__(self, st: strategy.GridStrategy, cfg: EasyDict):
        """
        Args:
            st: Grid strategy instance that generates signals.
            cfg: Parsed config section (host, port, logdir, notify, ...).
        """
        self.cfg = cfg
        # Throttled messages since the last pushed notification.
        self.signal_count = 0
        self.last_deal: Union[pd.Series, None] = None
        self.last_deal_signal: Union[pd.Series, None] = None
        self.prev_signal: Union[pd.Series, None] = None
        self.quote_ctx: Union[OpenQuoteContext, None] = None
        self.trade_ctx: Union[OpenSecTradeContext, None] = None
        self.strategy = st
        # Signal CSV logger; created in trade().
        self.logger = None

    def start_trade(self):
        """Open the quote connection and run live trading or a backtest.

        Backtest mode is selected when ``cfg.backtest`` is not None. In live
        mode the connection is left open; trade() decides when to close it.
        """
        self.quote_ctx = OpenQuoteContext(host=self.cfg.host,
                                          port=self.cfg.port)  # quote context
        if self.cfg.backtest is None:
            self.trade()
            return
        try:
            self.backtest()
        finally:
            # The backtest is synchronous: always release the connection,
            # even when it raised.
            self.quote_ctx.close()

    @staticmethod
    def _to_bar_frame(data: pd.DataFrame) -> pd.DataFrame:
        """Rename Futu kline columns and index bars by (datetime, order_book_id)."""
        data = data.rename(columns={
            'code': 'order_book_id',
            'time_key': 'datetime',
        }, errors='raise')
        data['datetime'] = pd.to_datetime(data['datetime'])
        return data.set_index(['datetime', 'order_book_id'])

    def _fetch_history_kline(self, start, end, max_count: int = 1000):
        """Fetch every 1-minute bar in [start, end], following pagination.

        Returns:
            (DataFrame, None) on success, or (None, error message) as soon as
            one page request fails.
        """
        bars = []
        page_req_key = None
        while True:
            ret, data, page_req_key = self.quote_ctx.request_history_kline(
                self.strategy.order_book_id,
                start=start,
                end=end,
                ktype=KLType.K_1M,
                max_count=max_count,
                page_req_key=page_req_key)
            if ret != RET_OK:
                return None, 'request_history_kline error: %s' % data
            bars.append(data)
            if page_req_key is None:  # no more pages
                break
        return pd.concat(bars), None

    def backtest(self):
        """Replay 1-minute history bars through the strategy and save results.

        Writes ``grid-futu.png`` (signal plot) and ``grid-futu.csv`` (signals)
        next to this script.

        Returns:
            An error message if history data could not be fetched, else None.
        """
        logging.info('回测开始日期: %s, 回测结束日期: %s' %
                     (self.cfg.backtest.start, self.cfg.backtest.end))
        data, error = self._fetch_history_kline(self.cfg.backtest.start,
                                                self.cfg.backtest.end)
        if error is not None:
            logging.error(error)
            return error
        data = self._to_bar_frame(data)

        signals = []
        dt = None
        for index, _ in data.iterrows():
            # Index is (datetime, order_book_id): level 0 is the bar time,
            # which is what the strategy expects as its timestamp.
            ndt = index[0]
            if dt != ndt:
                logging.info(ndt)
                dt = ndt
            signals.append(self.strategy.generate_signal(data.loc[[index]], dt))
        if not signals:
            logging.warning('No bars in backtest range, nothing to save')
            return None

        df = pd.concat(signals)
        out_dir = pathlib.Path(__file__).parent
        fig = utils.plot_trade_signals(df.reset_index(), 'datetime', 'open',
                                       '1m', self.strategy.grid)
        fig.savefig(out_dir.joinpath('grid-futu.png'))
        df.to_csv(out_dir.joinpath('grid-futu.csv'), index=False)
        return None

    def _setup_logging(self):
        """Attach daily-rotating handlers: run log (root) and signal CSV log."""
        # The file handlers open their file immediately and cannot create a
        # missing directory.
        pathlib.Path(self.cfg.logdir).mkdir(parents=True, exist_ok=True)

        trade_log = '{0}/{1}.log'.format(self.cfg.logdir,
                                         self.strategy.description)
        fhandler = handlers.TimedRotatingFileHandler(trade_log,
                                                     when='D',
                                                     backupCount=3)
        fhandler.suffix = '%Y%m%d.log'
        fhandler.setLevel(logging.INFO)
        root = logging.getLogger()
        root.addHandler(fhandler)
        root.setLevel(logging.INFO)

        # Signal rows go only to the CSV file, never to the root handlers.
        self.logger = logging.getLogger('signal_logger')
        signal_log = '{0}/{1}.csv'.format(self.cfg.logdir,
                                          self.strategy.description)
        handler = handlers.TimedRotatingFileHandler(signal_log,
                                                    when='D',
                                                    backupCount=3)
        handler.suffix = '%Y%m%d.csv'
        handler.setLevel(logging.INFO)
        self.logger.addHandler(handler)
        self.logger.setLevel(logging.INFO)
        self.logger.propagate = False

    def _restore_grid_state(self):
        """Restore the grid from the latest REAL deal and the saved signal.

        Reads ``<logdir>/last_deal_<code>.csv`` and overrides its price with
        the actual price of the most recent deal. Exits (fatal notification)
        if the deal history cannot be read or is empty.
        """
        # Always use REAL env when querying history deals:
        # no history deal data is available for the SIMULATE env.
        ret, data = self.trade_ctx.history_deal_list_query(
            code=self.strategy.order_book_id, trd_env=TrdEnv.REAL)
        if ret != RET_OK:
            self.notify_signal('history_deal_list_query error: %s' % data,
                               fatal=True)
        if data.shape[0] == 0:  # the deal list is empty
            self.notify_signal('无法找到历史成交', fatal=True)
        self.last_deal = data.sort_values('create_time',
                                          ascending=False).iloc[0]

        self.last_deal_signal = pd.read_csv(
            pathlib.Path(self.cfg.logdir).joinpath(
                'last_deal_%s.csv' % self.strategy.order_book_id)).iloc[0]
        self.last_deal_signal['price'] = self.last_deal.price
        self.strategy.init_state(
            self.last_deal_signal,
            self.strategy.grid_lots(self.last_deal_signal.grid_id))

        position = round(
            1 - float(self.last_deal_signal.grid_id) / len(self.strategy.grid),
            2)
        last_side = ('BUY' if self.last_deal_signal.signal
                     == strategy.TradeOp.BUY else 'SELL')
        logging.info('重启网格，网格区间: %s' % self.strategy.grid)
        logging.info(
            '当前网格: {}, 仓位: {:.2%}, 网格持仓:{}手, 总持仓{}手, 前一订单: {} at {}'.
            format(self.last_deal_signal.grid_id, position, self.strategy.lots,
                   self.get_holding_position(), last_side,
                   self.last_deal_signal.price))

    def _log_previous_bar_signal(self):
        """Evaluate the strategy on the latest 1-minute bar and log its signal."""
        end = datetime.now(TZ)
        ret, data, _ = self.quote_ctx.request_history_kline(
            self.strategy.order_book_id,
            start=(end - timedelta(days=5)).strftime('%Y-%m-%d'),
            end=end.strftime('%Y-%m-%d'),
            ktype=KLType.K_1M,
            max_count=2000)
        if ret != RET_OK:
            self.notify_signal('request_history_kline error: %s' % data,
                               fatal=True)
        data = self._to_bar_frame(data)
        s = self.strategy.generate_signal(data.tail(1), end).iloc[0]
        # Important! Re-initialise the restored state after generate_signal
        # (presumably it advances strategy state). Only a restored grid has a
        # saved signal; a fresh grid has last_deal_signal == None.
        # NOTE(review): a fresh grid keeps whatever state generate_signal
        # left behind — confirm this is acceptable for the strategy.
        if self.last_deal_signal is not None:
            self.strategy.init_state(
                self.last_deal_signal,
                self.strategy.grid_lots(self.last_deal_signal.grid_id))
        logging.info('前一k线交易信号: %s' % s.op)

    def _unlock_trade(self):
        """Prompt for the trade password and unlock REAL trading (fatal on failure)."""
        try:
            password = getpass.getpass()
        except Exception as error:
            logging.error('ERROR: %s', error)
            raise
        ret, data = self.trade_ctx.unlock_trade(password)
        if ret != RET_OK:
            # notify_signal(fatal=True) closes both contexts before exiting.
            self.notify_signal('解锁交易失败:%s' % data, fatal=True)
        self.notify_signal('解锁交易成功！实盘网格开始运行。')

    def trade(self):
        """Run the live (REAL) or simulated grid.

        Sets up logging and opens the trade context. When ``cfg.continued``
        is set, restores the grid from the last deal. Evaluates the latest
        bar and asks for confirmation. After confirmation it unlocks trading
        (REAL only), installs the push handlers and subscribes to data.
        Declining closes both contexts.
        """
        self._setup_logging()

        # Trade context; change its type to match the traded product.
        self.trade_ctx = OpenSecTradeContext(
            filter_trdmarket=TrdMarket.HK,
            host=self.cfg.host,
            port=self.cfg.port,
            security_firm=SecurityFirm.FUTUSECURITIES)
        logging.info('************  策略开始运行 ***********')

        # Only REAL trading needs unlocking, which happens after confirmation.
        self.cfg.trd_env = TrdEnv.REAL if self.cfg.real else TrdEnv.SIMULATE

        if self.cfg.continued:
            self._restore_grid_state()
        else:
            logging.info('启动新网格，网格区间: %s' % self.strategy.grid)

        self._log_previous_bar_signal()

        if not click.confirm('Continue?'):
            logging.info('网格退出')
            self.quote_ctx.close()
            self.trade_ctx.close()
            return

        if self.cfg.real:
            self._unlock_trade()
        else:
            logging.info('模拟交易模式！网格开始运行')

        self.strategy.order_by_signal = self.adjust_position
        # Install push callbacks.
        self.quote_ctx.set_handler(OnRTClass(self))
        self.trade_ctx.set_handler(OnOrderClass(self))
        self.trade_ctx.set_handler(OnFillClass(self))
        # Subscribe to the target's order book and the configured data feed.
        self.quote_ctx.subscribe(
            code_list=[self.strategy.order_book_id],
            subtype_list=[SubType.ORDER_BOOK, self.cfg.subtype])

    def notify_every_n(self, msg: str, n: int):
        """Push msg on every n-th call (counter is reset by any push); else log it."""
        self.signal_count += 1
        if self.signal_count % n == 0:
            self.notify_signal(msg)
        else:
            logging.info(msg)

    def _notify_wechat(self, url: str, msg: str):
        """POST msg as JSON ``{"msg": ...}`` to the configured wechat URL."""
        headers = {'Content-type': 'application/json', 'Accept': 'text/plain'}
        try:
            requests.post(url,
                          data=json.dumps({'msg': msg}),
                          headers=headers,
                          proxies={'http': ''},
                          timeout=self._HTTP_TIMEOUT)
        except Exception as e:
            logging.error(e)

    def _notify_pushdeer(self, pushkey, msg: str):
        """Send msg through the PushDeer push API."""
        try:
            # Pass query parameters via `params` so they are URL-encoded;
            # msg may contain '&', '#', '%' or spaces.
            requests.get('https://api2.pushdeer.com/message/push',
                         params={'pushkey': pushkey, 'text': msg},
                         proxies={'http': '', 'https': ''},
                         timeout=self._HTTP_TIMEOUT)
        except Exception as e:
            logging.error(e)

    def _notify_corpwechat(self, corpsecret, msg: str):
        """Send msg as a text message through the corp WeChat (WeCom) API."""
        # Hard-coded recipient and application identifiers.
        useridstr = 'mmBa'
        agentid = '1000002'
        corpid = 'wwe72114be147b892c'
        json_dict = {
            "touser": useridstr,
            "msgtype": "text",
            "agentid": agentid,
            "text": {
                "content": msg
            },
            "safe": 0,
            "enable_id_trans": 0,
            "enable_duplicate_check": 0,
            "duplicate_check_interval": 1800
        }
        try:
            # The token request is inside the try as well: a network error or
            # an error response (no access_token) must not crash the caller.
            response = requests.get(
                'https://qyapi.weixin.qq.com/cgi-bin/gettoken',
                params={'corpid': corpid, 'corpsecret': corpsecret},
                timeout=self._HTTP_TIMEOUT)
            access_token = response.json()['access_token']
            response_send = requests.post(
                'https://qyapi.weixin.qq.com/cgi-bin/message/send',
                params={'access_token': access_token},
                data=json.dumps(json_dict),
                timeout=self._HTTP_TIMEOUT)
            resp = response_send.json()
            if resp['errmsg'] != 'ok':
                logging.error(resp['errmsg'])
        except Exception as e:
            logging.error(e)

    def notify_signal(self, msg: str, fatal: bool = False):
        """Push msg to every configured channel, then log it.

        Channels, all optional under ``cfg.notify``: ``wechat`` (URL),
        ``pushdeer`` (push key), ``corpwechat`` (corp app secret). Delivery
        failures are logged, never raised.

        Args:
            msg: Message text.
            fatal: If True, log at FATAL level, close the Futu connections and
                call sys.exit(1).
        """
        notify = self.cfg.get('notify')
        if notify is not None:
            if notify.get('wechat') is not None:
                self._notify_wechat(notify.wechat, msg)
            if notify.get('pushdeer') is not None:
                self._notify_pushdeer(notify.pushdeer, msg)
            if notify.get('corpwechat') is not None:
                self._notify_corpwechat(notify.corpwechat, msg)

        if fatal:
            logging.fatal(msg)
            # Either context may not have been opened yet.
            for ctx in (self.quote_ctx, self.trade_ctx):
                if ctx is not None:
                    ctx.close()
            # NOTE(review): when called from a Futu push callback this may run
            # on a non-main thread, where SystemExit ends only that thread.
            sys.exit(1)
        logging.info(msg)
        self.signal_count = 0

    def get_holding_position(self) -> int:
        """Return the current holding of the target, in lots (qty / round_lot)."""
        ret, data = self.trade_ctx.position_list_query(
            code=self.strategy.order_book_id, trd_env=self.cfg.trd_env)
        if ret != RET_OK:
            self.notify_signal('获取持仓数据失败:%s' % data, fatal=True)
            return 0
        if data.shape[0] == 0:  # no position
            return 0
        return int(data['qty'].iloc[0] / self.cfg.round_lot)

    def get_ask_and_bid(self) -> Tuple[float, float]:
        """Return (ask1, bid1): the best ask and best bid from the order book."""
        ret, data = self.quote_ctx.get_order_book(self.strategy.order_book_id,
                                                  num=1)
        if ret != RET_OK:
            self.notify_signal('获取摆盘数据失败:%s' % data, fatal=True)
        # First level of each side, first field (presumably the price).
        return data['Ask'][0][0], data['Bid'][0][0]

    def adjust_position(self, signal: pd.Series):
        """Trade the difference between the signal's grid lots and current lots.

        REAL uses MARKET orders, SIMULATE uses NORMAL orders. The order is
        priced at ask1 for buys and bid1 for sells. A failed order is fatal.
        """
        hands = self.strategy.grid_lots(signal.grid_id) - self.strategy.lots
        if hands == 0:
            # Already at target. A zero-quantity order would presumably be
            # rejected, and a rejected order is treated as fatal below.
            logging.info('No position change needed for grid %s' %
                         signal.grid_id)
            return
        ask, bid = self.get_ask_and_bid()
        qty = abs(hands) * self.cfg.round_lot
        order_type = (OrderType.MARKET if self.cfg.trd_env == TrdEnv.REAL
                      else OrderType.NORMAL)
        if hands > 0:
            trd_side, price, op = TrdSide.BUY, ask, '买入'
        else:
            trd_side, price, op = TrdSide.SELL, bid, '卖出'

        self.notify_signal('下单:%s%s股%s' %
                           (op, qty, self.strategy.order_book_id))
        ret, data = self.trade_ctx.place_order(
            price=price,
            qty=qty,
            code=self.strategy.order_book_id,
            trd_side=trd_side,
            order_type=order_type,
            trd_env=self.cfg.trd_env,
            remark=self.strategy.description)
        if ret != RET_OK:
            self.notify_signal('下单失败:%s' % data, fatal=True)

    def get_max_quantity(self, price: float, trd_side: TrdSide) -> int:
        """Return the max quantity that can be bought/sold at price for trd_side."""
        # Use NORMAL instead of MARKET here as max_cash_buy can be smaller than
        # actual cash can buy with current price
        ret, data = self.trade_ctx.acctradinginfo_query(
            order_type=OrderType.NORMAL,
            code=self.strategy.order_book_id,
            price=price,
            trd_env=self.cfg.trd_env)
        if ret != RET_OK:
            # On failure data is an error message, so check before indexing.
            self.notify_signal('获取最大可买可卖失败:%s' % data, fatal=True)
            return 0
        max_can_buy = data['max_cash_buy'].iloc[0]
        max_can_sell = data['max_position_sell'].iloc[0]
        self.notify_signal('最大可买: %s, 最大可卖: %s' %
                           (max_can_buy, max_can_sell))
        return max_can_buy if trd_side == TrdSide.BUY else max_can_sell


class OnRTClass(RTDataHandlerBase):
    """RT_DATA push handler: feeds each tick to the strategy and reports signals."""

    # Ticks outside [09:30, 16:00] Hong Kong time are ignored
    # (the lunch break is not excluded).
    _SESSION_START = time(9, 30, 0)
    _SESSION_END = time(16, 0, 0)

    def __init__(self, strategy_ctx: StrategyContext):
        super().__init__()
        self.strategy_ctx = strategy_ctx

    def _in_session(self, now: datetime) -> bool:
        """Return True if now lies within today's trading window (inclusive)."""
        day = now.date()
        start = TZ.localize(datetime.combine(day, self._SESSION_START))
        end = TZ.localize(datetime.combine(day, self._SESSION_END))
        return start <= now <= end

    def on_recv_rsp(self, rsp_pb):
        """Handle one RT_DATA push.

        Steps:
        - Run the strategy on the first pushed row.
        - Log the resulting signal to the signal CSV logger.
        - Notify when the signal changes, otherwise via notify_every_n.
        - Persist trade signals to ``<logdir>/last_deal_<code>.csv`` so a
          restarted grid can resume from them.
        """
        now = datetime.now(TZ)
        if not self._in_session(now):
            logging.info('%s: 非交易时段' % now)
            return
        ret_code, data = super().on_recv_rsp(rsp_pb)
        ctx = self.strategy_ctx
        if ret_code != RET_OK:
            ctx.notify_signal('RT error: %s' % data, fatal=True)
            return

        data['datetime'] = pd.to_datetime(data.time)
        # RT data carries only cur_price: use it for every OHLC field.
        for field in ['open', 'close', 'low', 'high']:
            data[field] = data['cur_price']
        row = data.iloc[0]
        ctx.now = row.datetime
        ctx.strategy.handle_bar(ctx, {row.code: row})
        s = ctx.strategy.signal

        if ctx.last_deal is not None:
            last_deal = ctx.last_deal
            last_op = '%s %s %s at %s' % (last_deal.create_time,
                                          last_deal.trd_side,
                                          last_deal.qty, last_deal.price)
        else:
            last_op = 'N/A'
        msg = '{}: 交易信号:{}，仓位:{:.2%}，最近交易:{}'.format(
            row.datetime, s.op,
            1 - float(s.grid_id) / len(ctx.strategy.grid),
            last_op)
        ctx.logger.info(s.to_frame().transpose().to_csv(header=False).rstrip())

        if ctx.prev_signal is None or s.signal != ctx.prev_signal.signal:
            ctx.notify_signal(msg)
        else:
            # Roughly 1 RT push per second, so 600 pushes is about 10 minutes.
            ctx.notify_every_n(msg, n=600)
        # Keep a snapshot rather than a reference: if the strategy updates its
        # signal Series in place, a reference would make the comparison above
        # always equal.
        ctx.prev_signal = s.copy()

        if s.signal in [
                strategy.TradeOp.OPEN, strategy.TradeOp.BUY,
                strategy.TradeOp.CLOSE, strategy.TradeOp.SELL
        ]:
            # Write a copy with an integer signal code so the strategy's live
            # signal is not modified. int() handles numpy scalars and int-like
            # enums alike.
            record = s.copy()
            record['signal'] = int(record.signal)
            record.to_frame().transpose().to_csv(
                pathlib.Path(ctx.cfg.logdir).joinpath(
                    'last_deal_%s.csv' % ctx.strategy.order_book_id))
        # State of the grid is re-initialised on deal status update.


class OnOrderClass(TradeOrderHandlerBase):
    """Order-status push handler: reports updates for the active trading env."""

    def __init__(self, strategy_ctx: StrategyContext):
        super().__init__()
        self.strategy_ctx = strategy_ctx

    def on_recv_rsp(self, rsp_pb):
        """Notify the first pushed order's status; skip orders from another env."""
        ret, data = super().on_recv_rsp(rsp_pb)
        ctx = self.strategy_ctx
        if ret != RET_OK:
            ctx.notify_signal('Order error: %s' % data, fatal=True)
            return
        order = data.iloc[0]
        if order.trd_env != ctx.cfg.trd_env:
            logging.info('Skipping order from different env')
            return
        details = (order.updated_time, order.order_status, order.stock_name,
                   order.trd_side, order.order_type, order.price, order.qty)
        ctx.notify_signal('%s: 【订单状态】%s, %s %s %s %s %s' % details)


class OnFillClass(TradeDealHandlerBase):
    """Deal (fill) push handler: reports fills and re-anchors the grid state."""

    def __init__(self, strategy_ctx: StrategyContext):
        super().__init__()
        self.strategy_ctx = strategy_ctx

    def on_recv_rsp(self, rsp_pb):
        """Handle the first pushed deal for the strategy's own code.

        Notifies the fill along with the current holding in lots. Then sets
        the strategy's current signal price to the deal price and
        re-initialises the grid state from it.
        """
        ret, data = super().on_recv_rsp(rsp_pb)
        ctx = self.strategy_ctx
        if ret != RET_OK:
            ctx.notify_signal('Fill error: %s' % data, fatal=True)
            return
        deal = data.iloc[0]
        # Deals from other trades are pushed as well; ignore them.
        if deal.code != ctx.strategy.order_book_id:
            return
        lots_held = ctx.get_holding_position()
        ctx.notify_signal(
            '%s: 【成交状态】%s, %s %s %s %s, 持仓%s手' %
            (deal.create_time, deal.status, deal.stock_name,
             deal.trd_side, deal.price, deal.qty, lots_held))
        current_signal = ctx.strategy.signal
        current_signal['price'] = deal.price
        ctx.strategy.init_state(current_signal, lots_held)
        ctx.last_deal = deal


def main(_):
    """Load the `<code>.yml` config section and start the grid strategy."""
    config_path = pathlib.Path(__file__).parent / ('%s.yml' % FLAGS.code)
    # NOTE: UnsafeLoader can construct arbitrary Python objects, so the
    # config file must come from a trusted source.
    with config_path.open(mode='r') as f:
        doc = yaml.load(f, Loader=yaml.UnsafeLoader)
    cfg = EasyDict(doc[FLAGS.config])

    if cfg.secret is not None:
        # Enable protocol encryption with the configured RSA key file.
        SysConfig.enable_proto_encrypt(is_encrypt=True)
        SysConfig.set_init_rsa_file(pathlib.Path(cfg.secret))

    strategy_cls = getattr(strategy, cfg.strategy)
    grid_strategy = strategy_cls(**cfg.params)

    # Map the configured frequency to the real-time subscription type.
    freq = cfg.freq
    if freq == '1m':
        cfg.subtype = SubType.K_1M
    elif freq == 'rt':
        cfg.subtype = SubType.RT_DATA

    cfg.logdir = pathlib.Path(cfg.logdir).expanduser()
    StrategyContext(grid_strategy, cfg).start_trade()


# Script entry point: absl's app.run parses FLAGS, then calls main.
if __name__ == '__main__':
    app.run(main)
